package com.yangda.example.dao;

import com.yangda.example.framework.model.Pagination;
import com.yangda.example.framework.model.PageRequest;
import com.yangda.example.framework.utils.BeanTransformerAdapter;
import com.yangda.example.framework.utils.Reflections;
import com.yangda.example.interfaces.IBaseDao;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.*;
import org.hibernate.criterion.*;
import org.hibernate.internal.CriteriaImpl;
import org.hibernate.metadata.ClassMetadata;
import org.hibernate.transform.ResultTransformer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.Assert;

import java.io.Serializable;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Generic Hibernate DAO base class offering CRUD operations plus HQL, native SQL and
 * Criteria based queries with pagination support.
 *
 * <p>File name: BaseDao.java
 *
 * @param <T>  the mapped database entity type
 * @param <PK> the entity's primary key type
 */
public abstract class BaseDao<T, PK extends Serializable> implements IBaseDao<T, PK>{

	/** First stand-alone {@code from} keyword of an HQL statement (case-insensitive, not inside identifiers). */
	private static final Pattern FROM_KEYWORD = Pattern.compile("\\bfrom\\b", Pattern.CASE_INSENSITIVE);

	/** First {@code order by} clause of an HQL statement (case-insensitive, any whitespace between the words). */
	private static final Pattern ORDER_BY_KEYWORD = Pattern.compile("\\border\\s+by\\b", Pattern.CASE_INSENSITIVE);

	protected Logger logger = LoggerFactory.getLogger(getClass());

	@Autowired
	protected SessionFactory sessionFactory;

	protected Class<T> entityClass;

	/**
	 * Resolves the entity class from the subclass's generic type argument.
	 */
	public BaseDao() {
		this.entityClass = Reflections.getClassGenricType(getClass());
	}

	/**
	 * Creates a DAO with an explicit session factory and entity class.
	 *
	 * @param sessionFactory the Hibernate session factory to use
	 * @param entityClass    the entity class this DAO manages
	 */
	public BaseDao(final SessionFactory sessionFactory, final Class<T> entityClass) {
		this.sessionFactory = sessionFactory;
		this.entityClass = entityClass;
	}

	/**
	 * @return the session bound to the current context via {@link SessionFactory#getCurrentSession()}
	 */
	public Session getSession() {
		return sessionFactory.getCurrentSession();
	}

	/**
	 * @return the name of the entity's identifier property, taken from Hibernate metadata
	 * @throws IllegalArgumentException if {@code entityClass} is not a mapped entity
	 */
	@Override
	public String getIdName() {
		ClassMetadata meta = sessionFactory.getClassMetadata(entityClass);
		Assert.notNull(meta, "No Hibernate metadata found for entity class " + entityClass);
		return meta.getIdentifierPropertyName();
	}

	/**
	 * Inserts or updates the entity via {@link Session#saveOrUpdate(Object)}.
	 */
	@Override
	public void save(T entity) {
		getSession().saveOrUpdate(entity);
	}

	/**
	 * Deletes the given entity.
	 */
	@Override
	public void delete(T entity) {
		getSession().delete(entity);
	}

	/**
	 * Loads the entity by id and deletes it.
	 */
	@Override
	public void delete(PK id) {
		delete(find(id));
	}

	/**
	 * Deletes all entities whose id is in {@code ids} with a single bulk HQL statement.
	 * The ids are bound as a list parameter rather than concatenated into the query, so
	 * String keys are quoted correctly and cannot inject HQL.
	 *
	 * @param ids the ids to delete; a null or empty array is a no-op
	 */
	@Override
	public void batchDelete(PK[] ids) {
		if (ids == null || ids.length == 0) {
			// "in ()" is not valid HQL, and there is nothing to delete anyway
			return;
		}
		String hql = "delete from " + entityClass.getSimpleName() + " where " + getIdName() + " in (:ids)";
		// setProperties (used by createQuery(String, Map)) binds Collection values via setParameterList
		batchExecute(hql, Collections.singletonMap("ids", Arrays.asList(ids)));
	}

	/**
	 * Executes a bulk HQL update/delete with positional parameters.
	 *
	 * @return the number of affected rows
	 */
	@Override
	public int batchExecute(String hql, Object... values) {
		return createQuery(hql, values).executeUpdate();
	}

	/**
	 * Executes a bulk HQL update/delete with named parameters.
	 *
	 * @return the number of affected rows
	 */
	@Override
	public int batchExecute(String hql, Map values) {
		return createQuery(hql, values).executeUpdate();
	}

	/**
	 * @return the entity with the given id, or {@code null} if {@link Session#get} finds none
	 */
	@Override
	public T find(PK id) {
		return (T) getSession().get(entityClass, id);
	}

	/**
	 * @return the single entity whose {@code propertyName} equals {@code value}, or {@code null}
	 */
	@Override
	public <X> X findUniqueBy(String propertyName, Object value) {
		Criterion criterion = Restrictions.eq(propertyName, value);
		return (X) createCriteria(criterion).uniqueResult();
	}

	/**
	 * @return the unique result of the HQL query bound with named parameters
	 */
	@Override
	public <X> X  findUnique(String hql, Map values) {
		return (X) createQuery(hql, values).uniqueResult();
	}

	/**
	 * @return the unique entity matching all criterions
	 */
	@Override
	public <X> X findUnique(Criterion... criterions) {
		return (X) createCriteria(criterions).uniqueResult();
	}

	/**
	 * Loads all entities whose id is contained in {@code idList}.
	 *
	 * @return matching entities; an empty list when {@code idList} is null or empty
	 */
	@Override
	public List<T> find(Collection<PK> idList) {
		if (idList == null || idList.isEmpty()) {
			// avoid generating an (invalid) empty IN clause
			return new ArrayList<T>();
		}
		return find(Restrictions.in(getIdName(), idList));
	}

	/**
	 * @return every entity of this type
	 */
	@Override
	public List<T> findAll() {
		return find();
	}

	/**
	 * @return every entity of this type, ordered by {@code orderByProperty}
	 */
	@Override
	public List<T> findAll(String orderByProperty, boolean isAsc) {
		Criteria c = createCriteria();
		c.addOrder(isAsc ? Order.asc(orderByProperty) : Order.desc(orderByProperty));
		return c.list();
	}

	/**
	 * @return the results of the HQL query bound with named parameters
	 */
	@Override
	public List<T> find(String hql, Map values) {
		return createQuery(hql, values).list();
	}

	/**
	 * @return all entities whose {@code propertyName} equals {@code value}
	 */
	@Override
	public List<T> find(String propertyName, Object value) {
		Criterion criterion = Restrictions.eq(propertyName, value);
		return find(criterion);
	}

	/**
	 * @return all entities matching every criterion
	 */
	@Override
	public List<T> find(Criterion... criterions) {
		return createCriteria(criterions).list();
	}

	/**
	 * Runs a native SQL query.
	 *
	 * @param clazz when non-null, each row is mapped onto this bean type
	 * @return the result rows
	 */
	@Override
	public List query(String sql, Map values, Class clazz) {
		Query q = createSQLQuery(sql, values);
		if (clazz != null) {
			q.setResultTransformer(new BeanTransformerAdapter<>(clazz));
		}
		return q.list();
	}

	/**
	 * Runs a native SQL query and returns its first row.
	 *
	 * @param clazz when non-null, the row is mapped onto this bean type
	 * @return the first row, or {@code null} if the query returned nothing
	 */
	@Override
	public Object queryUnique(String sql, Map values, Class clazz) {
		List list = query(sql, values, clazz);
		return (list == null || list.isEmpty()) ? null : list.get(0);
	}

	/** Paged HQL query; the total is computed with an auto-derived count query. */
	private Pagination findPage(final String hql, Integer pageIndex, Integer pageSize, final Map values, Class clazz) {
		Query q = createQuery(hql, values);
		if (clazz != null) {
			q.setResultTransformer(new BeanTransformerAdapter<>(clazz));
		}
		int totalCount = countHqlResult(hql, values);
		setPageParameterToQuery(q, pageIndex, pageSize);
		List result = q.list();
		return getPageResult(pageIndex, pageSize, totalCount, result);
	}

	/** Paged Criteria query with optional multi-column ordering. */
	private Pagination findPage(Integer pageIndex, Integer pageSize, String orderBy, String order, final Criterion... criterions) {
		Criteria c = createCriteria(criterions);
		int totalCount = countCriteriaResult(c);
		setPageParameterToCriteria(c, pageIndex, pageSize, order, orderBy);
		List result = c.list();
		return getPageResult(pageIndex, pageSize, totalCount, result);
	}

	/** Paged native SQL query; the total is computed by wrapping the SQL in a COUNT(*) subquery. */
	private Pagination query(final String sql, Integer pageIndex, Integer pageSize, final Map  values, Class clazz){
		Query q = createSQLQuery(sql, values);
		if (clazz != null) {
			q.setResultTransformer(new BeanTransformerAdapter<>(clazz));
		}
		int totalCount = getTotalCount(countSql(sql), values);
		setPageParameterToQuery(q, pageIndex, pageSize);
		List result = q.list();
		return getPageResult(pageIndex, pageSize, totalCount, result);
	}

	@Override
	public Pagination findPage(String hql, PageRequest pageRequest, Map values, Class clazz) {
		return findPage(hql, pageRequest.getCurrentPage(), pageRequest.getPageSize(), values, clazz);
	}

	@Override
	public Pagination<T> findPage(String hql, PageRequest pageRequest, Map values) {
		return findPage(hql, pageRequest.getCurrentPage(), pageRequest.getPageSize(), values, null);
	}

	@Override
	public Pagination<T> findPage(PageRequest pageRequest, Criterion... criterions) {
		return findPage(pageRequest.getCurrentPage(), pageRequest.getPageSize(), pageRequest.getOrderBy(), pageRequest.getOrder(), criterions);
	}

	@Override
	public Pagination query(String sql, PageRequest pageRequest, Map values, Class clazz) {
		return query(sql, pageRequest.getCurrentPage(), pageRequest.getPageSize(), values, clazz);
	}

	/**
	 * Counts the rows an HQL query would return, using a derived {@code select count(*)} query.
	 *
	 * @throws RuntimeException if the derived count query cannot be executed
	 */
	private int countHqlResult(String hql, Map values) {
		String countHql = prepareCountHql(hql);

		try {
			Number count = findUnique(countHql, values);
			return count == null ? 0 : toIntCount(count.longValue());
		} catch (Exception e) {
			throw new RuntimeException("hql can't be auto count, hql is:" + countHql, e);
		}
	}

	/** Wraps a native SQL statement in a COUNT(*) over a derived table. */
	private String countSql(String sql){
		return "SELECT COUNT(*) FROM ( "+sql+" ) temp";
	}

	/**
	 * Counts the rows a Criteria would return without disturbing its projection,
	 * result transformer or ordering.
	 */
	private int countCriteriaResult(final Criteria c) {
		CriteriaImpl impl = (CriteriaImpl) c;

		// Save the Projection, ResultTransformer and OrderBy entries, clear them, then run the count
		Projection projection = impl.getProjection();
		ResultTransformer transformer = impl.getResultTransformer();

		List<CriteriaImpl.OrderEntry> orderEntries = null;
		try {
			orderEntries = (List) Reflections.getFieldValue(impl, "orderEntries");
			// ORDER BY is irrelevant for a row count and may be rejected alongside an aggregate
			Reflections.setFieldValue(impl, "orderEntries", new ArrayList());
		} catch (Exception e) {
			logger.error("不可能抛出的异常:{}", e.getMessage(), e);
		}

		// Run the count query
		Number totalCountObject = (Number) c.setProjection(Projections.rowCount()).uniqueResult();
		long totalCount = (totalCountObject != null) ? totalCountObject.longValue() : 0L;

		// Restore the previous Projection, ResultTransformer and OrderBy state
		c.setProjection(projection);

		if (projection == null) {
			c.setResultTransformer(CriteriaSpecification.ROOT_ENTITY);
		}
		if (transformer != null) {
			c.setResultTransformer(transformer);
		}
		// Only write back what was actually captured; restoring null would break later list() calls
		if (orderEntries != null) {
			try {
				Reflections.setFieldValue(impl, "orderEntries", orderEntries);
			} catch (Exception e) {
				logger.error("不可能抛出的异常:{}", e.getMessage(), e);
			}
		}

		return toIntCount(totalCount);
	}

	/**
	 * Narrows a row count to {@code int}.
	 *
	 * @throws IllegalStateException if the count does not fit into an {@code int}
	 */
	private static int toIntCount(long count) {
		if (count > Integer.MAX_VALUE || count < Integer.MIN_VALUE) {
			throw new IllegalStateException("Row count does not fit into an int: " + count);
		}
		return (int) count;
	}

	/**
	 * Creates an HQL query and binds {@code values} as positional parameters (index 0 first).
	 */
	public Query createQuery(final String queryString, final Object... values) {
		Query query = getSession().createQuery(queryString);
		if (values != null) {
			for (int i = 0; i < values.length; i++) {
				query.setParameter(i, values[i]);
			}
		}
		return query;
	}

	/**
	 * Creates an HQL query and binds {@code values} as named parameters via {@link Query#setProperties(Map)}.
	 */
	public Query createQuery(final String queryString, final Map<String, ?> values) {
		Query query = getSession().createQuery(queryString);
		if (values != null) {
			query.setProperties(values);
		}
		return query;
	}

	/**
	 * Creates a native SQL query and binds {@code values} as named parameters via {@link Query#setProperties(Map)}.
	 */
	public SQLQuery createSQLQuery(final String queryString, final Map values) {
		SQLQuery sqlQuery = getSession().createSQLQuery(queryString);
		if (values != null) {
			sqlQuery.setProperties(values);
		}
		return sqlQuery;
	}

	/**
	 * Creates a Criteria for the entity class with all given restrictions added.
	 *
	 * @param criterions restrictions to add; a null array is treated as empty
	 */
	public Criteria createCriteria(final Criterion... criterions) {
		Criteria criteria = getSession().createCriteria(entityClass);
		if (criterions != null) {
			for (Criterion c : criterions) {
				criteria.add(c);
			}
		}
		return criteria;
	}

	/**
	 * Computes the zero-based first-row offset for a one-based page index.
	 * A null or non-positive page index is treated as the first page.
	 *
	 * @throws IllegalArgumentException if {@code pageSize} is null or not positive
	 */
	private int firstResultOf(Integer pageIndex, Integer pageSize) {
		Assert.isTrue(pageSize != null && pageSize > 0, "Page Size must larger than zero");
		int page = (pageIndex == null || pageIndex < 1) ? 1 : pageIndex;
		return (page - 1) * pageSize;
	}

	/** Applies offset/limit paging to an HQL or SQL query. */
	private Query setPageParameterToQuery(final Query q, Integer pageIndex,Integer pageSize) {
		// Hibernate's firstResult is zero-based
		q.setFirstResult(firstResultOf(pageIndex, pageSize));
		q.setMaxResults(pageSize);
		return q;
	}

	/**
	 * Applies offset/limit paging and optional ordering to a Criteria.
	 *
	 * @param order   comma-separated directions; an entry equal to {@code PageRequest.ASC} sorts ascending,
	 *                anything else descending
	 * @param orderBy comma-separated property names; blank means no ordering
	 */
	private Criteria setPageParameterToCriteria(final Criteria c , Integer pageIndex,Integer pageSize,String order,String orderBy) {

		// Hibernate's firstResult is zero-based
		c.setFirstResult(firstResultOf(pageIndex, pageSize));
		c.setMaxResults(pageSize);

		if (StringUtils.isNotBlank(orderBy)) {
			String[] orderByArray = StringUtils.split(orderBy, ',');
			String[] orderArray = StringUtils.split(order, ',');

			// Every sort property needs a matching direction (order == null yields a null array)
			Assert.isTrue(orderArray != null && orderByArray.length == orderArray.length, "分页多重排序参数中,排序字段与排序方向的个数不相等");

			for (int i = 0; i < orderByArray.length; i++) {
				String property = orderByArray[i].trim();
				if (PageRequest.ASC.equals(orderArray[i].trim())) {
					c.addOrder(Order.asc(property));
				} else {
					c.addOrder(Order.desc(property));
				}
			}
		}
		return c;
	}

	/**
	 * Builds the {@link Pagination} result.
	 * Note: a total of 0 rows yields {@code totalPage == 1} when {@code pageSize} is positive.
	 */
	private Pagination getPageResult(Integer pageIndex, Integer pageSize, int totalCount, List result) {
		// Compute the total number of pages
		int totalPage = (pageSize == null || pageSize <= 0) ? 0 : (totalCount - 1) / pageSize + 1;
		boolean validPage = pageIndex != null && pageIndex > 0;
		boolean hasPrevPage = validPage && pageIndex > 1;
		// "<" instead of "!=": a page index past the last page must not report a next page
		boolean hasNextPage = validPage && pageIndex < totalPage;
		return new Pagination(totalCount, totalPage, pageSize, pageIndex, hasPrevPage, hasNextPage, result);
	}

	/**
	 * Derives a {@code select count(*)} HQL statement from a query.
	 * This is a simple textual transformation: the select clause (everything before the first
	 * stand-alone {@code from}) and the first {@code order by} clause are dropped, both matched
	 * case-insensitively.
	 */
	private String prepareCountHql(String orgHql) {
		// Drop the select clause; a missing "from" leaves a bare "from " that fails on execution
		Matcher from = FROM_KEYWORD.matcher(orgHql);
		String fromHql = "from " + (from.find() ? orgHql.substring(from.end()) : "");

		// Drop the order by clause, which is meaningless for a count
		Matcher orderBy = ORDER_BY_KEYWORD.matcher(fromHql);
		if (orderBy.find()) {
			fromHql = fromHql.substring(0, orderBy.start());
		}

		return "select count(*) " + fromHql;
	}

	/**
	 * Executes a native count query. Parameters are bound exactly like the data query
	 * (via {@link #createSQLQuery(String, Map)}), so collection-valued parameters expand the same way.
	 *
	 * @return the count, or 0 if the query returned no row
	 */
	private Integer getTotalCount(String sql, Map params) {
		Number count = (Number) createSQLQuery(sql, params).uniqueResult();
		return count == null ? 0 : toIntCount(count.longValue());
	}
}
